# =============================================================================
# ENTITY-LEVEL KRONECKER SDM PANEL (Province × Sector) - FIXED FOR STRING sector_id
# =============================================================================
import pandas as pd
import numpy as np
import geopandas as gpd
from libpysal.weights import Queen, W
from scipy.sparse import kron as kronecker_product
from scipy.sparse import csr_matrix
from spreg import Panel_FE_Lag
from scipy import stats
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from esda.moran import Moran
from datetime import datetime
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

print("=== ENTITY-LEVEL KRONECKER SDM PANEL (TRUE SDM - Province×Sector) ===")

# --------------------------------------------------
# 1. Load the dataset
# --------------------------------------------------
# Later sections read the columns province_name, province_id, sector_id,
# year, ln_wage and every regressor listed in `exog_vars` from this frame.
df = pd.read_csv("dataset.csv")
print(f"Loaded dataset: {df.shape[0]:,} observations")

# --------------------------------------------------
# 2. Define variables
# --------------------------------------------------
# Regressors of the model; each one also gets a spatial lag (W_<name>)
# further down in the script.
exog_vars = [
    'spe',
    'div',
    'den',
    'ln_fdi',
    'ln_productivity_lag',
    'skilled_share',
]

# --------------------------------------------------
# 3. Province base weights (Queen)
# --------------------------------------------------
# Provinces are matched between the dataset and the GeoJSON on a normalised
# (lower-cased, stripped) name. Only provinces present in BOTH sources are
# kept, so that every row of the weights matrix has data behind it.
gdf = gpd.read_file("province_vn.geojson")
gdf['norm'] = gdf['ten_tinh'].str.lower().str.strip()
df['prov_norm'] = df['province_name'].str.lower().str.strip()

# Missing names are excluded up front: NaN must never count as a "match".
data_names = set(df['prov_norm'].dropna())
map_names = set(gdf['norm'].dropna())
common = data_names & map_names

# A name mismatch (spelling, diacritics, ...) silently shrinks the panel, so
# report exactly what is being dropped on each side.
for origin, unmatched in (("dataset", data_names - common),
                          ("GeoJSON", map_names - common)):
    if unmatched:
        print(f"WARNING: {len(unmatched)} province name(s) found only in the "
              f"{origin} are dropped: {sorted(unmatched)}")
if not common:
    raise ValueError(
        "No province names match between dataset.csv (province_name) and "
        "province_vn.geojson (ten_tinh)."
    )

df = df[df['prov_norm'].isin(common)].copy()
# reset_index so that row position == weight id (use_index=False below).
gdf = gdf[gdf['norm'].isin(common)].copy().reset_index(drop=True)

w_prov = Queen.from_dataframe(gdf, use_index=False)
if w_prov.islands:
    # Provinces without contiguous neighbours have an all-zero row, i.e. no
    # spatial lag at all.
    print(f"WARNING: {len(w_prov.islands)} province(s) have no Queen neighbours: "
          f"{sorted(gdf.loc[w_prov.islands, 'norm'])}")
w_prov.transform = 'r'   # row-standardise
print(f"Province Queen weights: {w_prov.n} provinces, avg neighbors = {w_prov.mean_neighbors:.2f}")

# --------------------------------------------------
# 4. Build Kronecker entity weights (Geo ⊗ Sector)
# --------------------------------------------------
# Layout of the Kronecker product kron(W_prov, S):
#   entity row = province_position * n_sec + sector_position
# where province_position is the row of the province in `gdf` and
# sector_position is the row of the sector in sector_sic.csv.
n_prov = w_prov.n
n_sec = df['sector_id'].nunique()   # the original author's note expects 18

sic_df = pd.read_csv("sector_sic.csv")
sic_values = sic_df['SIC'].values

# The sector block must have exactly one row per sector in the data,
# otherwise n_entity below disagrees with the size of the weights matrix.
if len(sic_values) != n_sec:
    raise ValueError(
        f"sector_sic.csv has {len(sic_values)} rows but the dataset contains "
        f"{n_sec} distinct sector_id values; they must match."
    )
n_entity = n_prov * n_sec

# Sector similarity: 1 for identical SIC codes, falling linearly to 0 for the
# two codes that are furthest apart.
max_dist = sic_values.max() - sic_values.min()
if max_dist == 0:
    # All SIC codes identical (or a single sector): every pair is maximally
    # similar. Avoids a 0/0 division that would fill S with NaN.
    S = np.ones((len(sic_values), len(sic_values)))
else:
    S = 1 - np.abs(sic_values[:, None] - sic_values[None, :]) / max_dist
np.fill_diagonal(S, 1.0)
S_sparse = csr_matrix(S)

W_entity_sparse = kronecker_product(w_prov.sparse, S_sparse).tocsr()
w_entity = W.from_sparse(W_entity_sparse)
w_entity.transform = 'r'   # row-standardise the combined matrix
print(f"Kronecker entity weights: {w_entity.n} entities, avg neighbors = {w_entity.mean_neighbors:.2f}")

# --------------------------------------------------
# 5. SAFE entity_idx (handles string sector_id like 'BBL')
# --------------------------------------------------
# entity_idx must be the ROW of the entity in the Kronecker weights matrix:
#   entity_idx = province_position * n_sectors + sector_position
# Numbering lexicographically sorted "province_sector" strings does not give
# that layout, so positions are derived from the objects the weights were
# built from.
df['prov_sec_key'] = df['province_id'].astype(str) + '_' + df['sector_id'].astype(str)

# Province position = row of the province in gdf (weights use row positions
# as ids because they were built with use_index=False).
if gdf['norm'].duplicated().any():
    raise ValueError(
        "Province names in the GeoJSON are not unique; cannot map provinces "
        "to rows of the weights matrix."
    )
prov_lookup = pd.Series(np.arange(len(gdf)), index=gdf['norm'].to_numpy())
prov_pos = df['prov_norm'].map(prov_lookup)

# Sector position = row of the sector in sector_sic.csv.
sector_labels = df['sector_id'].astype(str)
if 'sector_id' in sic_df.columns:
    sector_order = sic_df['sector_id'].astype(str).tolist()
else:
    # NOTE(review): sector_sic.csv exposes no sector_id column here, so its
    # rows are ASSUMED to follow the sorted order of sector_id - confirm.
    sector_order = sorted(sector_labels.unique())
if len(sector_order) != len(sic_df):
    raise ValueError(
        f"{len(sector_order)} sectors in the data vs {len(sic_df)} rows in "
        "sector_sic.csv; cannot align sectors with the weights matrix."
    )
sec_pos = sector_labels.map({label: pos for pos, label in enumerate(sector_order)})

if prov_pos.isna().any() or sec_pos.isna().any():
    raise ValueError("Some observations could not be mapped to a province/sector position.")
df['entity_idx'] = (prov_pos * len(sector_order) + sec_pos).astype(int)

# The panel estimator needs a balanced panel: exactly one row per entity and
# year, for every row of the weights matrix.
n_years = df['year'].nunique()
if df.duplicated(['year', 'entity_idx']).any() or len(df) != n_years * w_entity.n:
    raise ValueError(
        f"Panel is not balanced: expected {n_years} years x {w_entity.n} entities "
        f"= {n_years * w_entity.n} unique rows, found {len(df)}."
    )

# CRITICAL: spreg reads long-format panels time-major (y[0:N] is the first
# period, y[N:2N] the second, ...), so sort by year first, then by entity.
df = df.sort_values(['year', 'entity_idx']).reset_index(drop=True)
print(f"Entity mapping complete: {n_entity} entities")
print(f"Data sorted by year + entity_idx: {len(df):,} observations")

# --------------------------------------------------
# 6. Create WX at entity level
# --------------------------------------------------
# For every year, place each regressor on the full entity grid, pre-multiply
# by the entity weights and write the lag back to the SAME rows it was taken
# from. Writing by row position keeps WX aligned with df whatever the sort
# order of df is (concatenating the per-year results only lines up when df
# is sorted year-major).
W_matrix = w_entity.full()[0]
n_grid = W_matrix.shape[0]

year_of_row = df['year'].to_numpy()
entity_of_row = df['entity_idx'].to_numpy()
# Mean within (year, entity): identical to the raw value when each cell holds
# one row, and collapses duplicates to a single value otherwise.
cell_means = df.groupby(['year', 'entity_idx'])[exog_vars].transform('mean').to_numpy()

wx_array = np.zeros((len(df), len(exog_vars)))
for year in np.unique(year_of_row):
    rows = np.flatnonzero(year_of_row == year)
    entities = entity_of_row[rows]
    # Entities without an observation in this year stay at 0 on the grid.
    grid = np.zeros((n_grid, len(exog_vars)))
    grid[entities] = cell_means[rows]
    # One product lags all regressors of the year at once.
    wx_array[rows] = (W_matrix @ grid)[entities]

for i, var in enumerate(exog_vars):
    df[f'W_{var}'] = wx_array[:, i]
print("WX variables created at entity level.")

# --------------------------------------------------
# 7. Year demeaning
# --------------------------------------------------
# Subtract the cross-sectional mean of each year from the dependent variable
# and from every regressor; results go to <name>_demean_year columns.
demean_cols = ['ln_wage'] + exog_vars
year_means = df.groupby('year')[demean_cols].transform('mean')
for name in demean_cols:
    df[f'{name}_demean_year'] = df[name] - year_means[name]
print("Year demeaning applied.")

# --------------------------------------------------
# 8. Joint orthogonalization of WX
# --------------------------------------------------
# Each spatial lag W_<var> is regressed on a constant plus ALL year-demeaned
# regressors; the OLS residual is stored as W_<var>_joint_orth. The largest
# absolute correlation of that residual with the regressors is printed as a
# check.
print("Joint orthogonalizing WX lags...")
X_full = df[[f'{v}_demean_year' for v in exog_vars]]
X_full_with_const = sm.add_constant(X_full)
for var in exog_vars:
    orth_name = f'W_{var}_joint_orth'
    model = sm.OLS(df[f'W_{var}'], X_full_with_const).fit()
    df[orth_name] = model.resid
    abs_corrs = [abs(df[orth_name].corr(df[x_name])) for x_name in X_full.columns]
    max_corr = max(abs_corrs)
    print(f"Joint orth {var}: max corr with X = {max_corr:.4f}")
print("Orthogonalization complete.")

# --------------------------------------------------
# 9. Run TRUE Entity-level SDM
# --------------------------------------------------
# Design matrix: the year-demeaned regressors followed by the orthogonalised
# spatial lags, in the order of `exog_vars`.
x_vars_sdm = [f'{v}_demean_year' for v in exog_vars]
x_vars_sdm += [f'W_{v}_joint_orth' for v in exog_vars]

# Dependent variable as an (observations x 1) column vector.
y = df['ln_wage_demean_year'].to_numpy().reshape(-1, 1)
X = df[x_vars_sdm].to_numpy()

sdm = Panel_FE_Lag(
    y=y,
    x=X,
    w=w_entity,
    name_y='ln_wage',
    name_x=x_vars_sdm,
    name_ds='Vietnam Entity (Prov×Sec) SDM 2018–2022 (Kronecker)',
)

rule = "=" * 110
print("\n" + rule)
print("ENTITY-LEVEL KRONECKER SDM PANEL - TRUE SDM")
print(rule)
print(sdm.summary)

# --------------------------------------------------
# 10. Diagnostics
# --------------------------------------------------
# (a) Likelihood-ratio test of the SDM against the nested spatial-lag (SAR)
#     model estimated WITHOUT the WX terms. The restriction under test is
#     "all WX coefficients are zero" (one per regressor), which is not the
#     common-factor restriction (that one nests the spatial error model), so
#     the printed label names the hypothesis actually tested.
x_sar = df[[f'{v}_demean_year' for v in exog_vars]].values
sar = Panel_FE_Lag(y=y, x=x_sar, w=w_entity)
lr_stat = 2 * (sdm.logll - sar.logll)
p_lr = stats.chi2.sf(lr_stat, len(exog_vars))
print(f"\nLR test SDM vs SAR (H0: all WX coefficients = 0): stat = {lr_stat:.3f}, p = {p_lr:.4f} → {'reject' if p_lr < 0.05 else 'fail to reject'}")

# (b) Moran's I of the SDM residuals, one cross-section per year.
df['sdm_resid'] = sdm.u.flatten()
print("\nMoran's I by year (entity-level residuals):")
moran_res_sdm = []
for year in sorted(df['year'].unique()):
    # Order the cross-section by entity_idx so that residual k belongs to
    # row k of w_entity, independent of how df itself is sorted.
    year_df = df[df['year'] == year].sort_values('entity_idx')
    resid_vec = year_df['sdm_resid'].values
    mi = Moran(resid_vec, w_entity, permutations=999)
    moran_res_sdm.append((year, mi.I, mi.p_sim))
    sig = "***" if mi.p_sim < 0.01 else "**" if mi.p_sim < 0.05 else "*" if mi.p_sim < 0.10 else ""
    print(f"Year {year}: I = {mi.I:.4f}{sig} (p={mi.p_sim:.4f})")

# (c) Variance inflation factors of the SDM design matrix (plus constant).
print("\nVIF CHECK")
X_vif = sm.add_constant(df[x_vars_sdm])
vif_matrix = X_vif.values
# Feature names come from the matrix itself so they always line up with the
# columns the VIFs were computed for.
vif_data = pd.DataFrame({
    'feature': list(X_vif.columns),
    'VIF': [variance_inflation_factor(vif_matrix, j) for j in range(vif_matrix.shape[1])],
})
print(vif_data.round(2))

# --------------------------------------------------
# 11. EXACT LeSage & Pace (2009) EFFECTS
# --------------------------------------------------
# For regressor r the effects matrix is
#     S_r = (I - rho*W)^-1 (beta_r*I + theta_r*W)
# and the scalar summaries are
#     direct   = trace(S_r) / N            (average own-entity effect)
#     total    = sum of all entries / N    (average ROW SUM of S_r)
#     indirect = total - direct
# The total must be divided by N, not N^2: the plain mean of all N^2 entries
# understates it by a factor of N and corrupts the indirect effect too.
# NOTE(review): theta_r is the coefficient on the ORTHOGONALISED lag
# (W_<var>_joint_orth), not on the raw WX - confirm this is the intended
# input for the effects formula.
print("\n" + "="*110)
print("LeSAGE & PACE (2009) EFFECTS – EXACT (full matrix) - Kronecker")
print("="*110)

# sdm.rho may be a 1-element array; extract the scalar explicitly.
rho = float(np.asarray(sdm.rho).ravel()[0])
n = len(exog_vars)
betas = sdm.betas.flatten()[:n]        # coefficients on X
thetas = sdm.betas.flatten()[n:2*n]    # coefficients on the WX terms

W_mat = w_entity.full()[0]
n_units = W_mat.shape[0]
I = np.eye(n_units)
S_inv = np.linalg.inv(I - rho * W_mat)
# Shared by every regressor, so computed once outside the loop.
S_inv_W = S_inv @ W_mat

# (variable, direct, indirect, total), computed once and reused for both the
# console table and the results file.
effects_table = []
for i, var in enumerate(exog_vars):
    effects = betas[i] * S_inv + thetas[i] * S_inv_W
    direct = np.trace(effects) / n_units
    total = effects.sum() / n_units
    indirect = total - direct
    effects_table.append((var, direct, indirect, total))

print(f"{'Variable':20} {'Direct':>12} {'Indirect':>12} {'Total':>12}")
for var, direct, indirect, total in effects_table:
    print(f"{var:20} {direct:12.4f} {indirect:12.4f} {total:12.4f}")

# --------------------------------------------------
# 12. SAVE
# --------------------------------------------------
# Plain-text report: model summary, effects, LR test, Moran's I and VIF table.
output_file = "SDM_Kronecker_results.txt"
with open(output_file, "w", encoding="utf-8") as f:
    f.write("="*110 + "\n")
    f.write("ENTITY-LEVEL KRONECKER SDM PANEL – TRUE SDM (Exact Effects)\n")
    f.write(f"Date: {datetime.now().strftime('%B %d, %Y at %H:%M')}\n")
    f.write("="*110 + "\n\n")
    f.write(str(sdm.summary) + "\n\n")
    f.write("LeSage & Pace Effects (EXACT):\n")
    for var, direct, indirect, total in effects_table:
        f.write(f"{var:20} Direct: {direct:8.4f}  Indirect: {indirect:8.4f}  Total: {total:8.4f}\n")
    # lr_stat compares the SDM with the SAR model without WX terms, so the
    # label names that hypothesis rather than a common-factor test.
    f.write(f"\nLR test SDM vs SAR (H0: all WX coefficients = 0): stat = {lr_stat:.3f}, p = {p_lr:.4f}\n")
    f.write("Moran's I on residuals:\n")
    for year, mi_val, p_val in moran_res_sdm:
        f.write(f"Year {year}: I = {mi_val:.4f} (p={p_val:.4f})\n")
    f.write("\nVIF table:\n")
    f.write(vif_data.round(2).to_string())

print(f"\n✅ Kronecker entity-level results successfully saved to: {output_file}")